#%% 
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import seaborn as sns
import os

import argparse
# Command-line options. Every value is forwarded to the explainer training
# call (simshap.train) at the end of this script.
parser = argparse.ArgumentParser(
    description='Train a SimSHAP explainer on the Online News Popularity data.'
)
parser.add_argument(
    '--lr', type=float, default=7e-4,
    help='learning rate for explainer training (default: %(default)s)')
parser.add_argument(
    '--batch_size', type=int, default=1024,
    help='batch size for explainer training (default: %(default)s)')
parser.add_argument(
    '--num_samples', type=int, default=32,
    help='value passed as num_samples to explainer training '
         '(default: %(default)s)')
parser.add_argument(
    '--epochs', type=int, default=1000,
    help='maximum number of explainer training epochs (default: %(default)s)')
parser.add_argument(
    '--pair_sampling', action='store_true',
    help='enable paired sampling during explainer training')
# Parsed once at start-up; reads the process's command line.
args = parser.parse_args()
#%% 
# Read the raw table; skipinitialspace drops blanks that follow a delimiter.
df = pd.read_csv(
    "OnlineNewsPopularity/OnlineNewsPopularity.csv",
    skipinitialspace=True,
)

# Target column (number of shares).
sr_Y = df['shares']

# Predictors: everything except the non-predictive columns and the target
# (58 features remaining).
df_X = df.drop(columns=['url', 'timedelta', 'shares'])


#%% preprocessing
'''
Input features:

After dropping non-predictive input features, there are 58 input features remaining.

Data types: All the input features have numerical values, but counting the unique values shows that some of them are actually binary, and some of those binary features are related in content. We can therefore merge them into multinomial features encoded by integers to reduce dimensionality.

Feature scales: The scales (ranges) of these features differ considerably, so standardization is needed before fitting models on them.

Outliers: If we define values more than 5 standard deviations from the mean as outliers, some features have a considerable number of outliers, so outlier removal is also needed in preprocessing.
'''

def outlierCounts(series):
    """Count the entries lying at least 5 standard deviations from the mean.

    in: series of numeric values
    out: number of entries whose absolute distance from the series mean is
         greater than or equal to five times the series standard deviation
    """
    distance_from_mean = np.abs(series - series.mean())
    cutoff = series.std() * 5
    is_outlier = distance_from_mean >= cutoff
    # Summing a boolean mask counts its True entries.
    return int(is_outlier.sum())

def uniqueValueCount(series):
    """Return the number of distinct values in the series.

    Missing values are not dropped: they are counted like any other value.
    """
    return series.nunique(dropna=False)

# One row per input feature: its name and dtype first, statistics after.
input_feats = df_X.dtypes.reset_index()
input_feats.columns = ['name', 'dtype']

# Each statistic below comes back indexed by feature name; dropping that
# index lines it up positionally with the rows of input_feats.
input_feats['mean'] = df_X.mean().reset_index(drop=True)
input_feats['std'] = df_X.std().reset_index(drop=True)

value_span = df_X.max() - df_X.min()
input_feats['range'] = value_span.reset_index(drop=True)

unique_per_feature = df_X.apply(uniqueValueCount, axis=0)
input_feats['unique_values_count'] = unique_per_feature.reset_index(drop=True)

outliers_per_feature = df_X.apply(outlierCounts, axis=0)
input_feats['outliers_count'] = outliers_per_feature.reset_index(drop=True)


#%% Merge Binary Features
'''
Among those binary features, 6 describe the content category of the news article and 7 describe the publication weekday. We can merge each group into a single feature and reduce the number of features to 47.
'''
def mergeFeatures(df, old_feats, new_feat):
    """Collapse a group of binary columns into one integer-coded column.

    The dataframe is modified in place: ``new_feat`` is added and every
    column listed in ``old_feats`` is deleted.

    in: df (dataframe), old_feats (names of the binary columns),
        new_feat (name of the merged multinomial column)
    out: the same dataframe object. A row gets code k when the k-th column of
         old_feats (counting from 1) equals 1.0; if several columns are set,
         the last one listed wins; if none is set, the code stays 0.
    """
    df[new_feat] = 0

    for code, flag_column in enumerate(old_feats, start=1):
        is_set = df[flag_column] == 1.0
        df.loc[is_set, new_feat] = code
        # The binary column is no longer needed once it has been encoded.
        del df[flag_column]

    return df

# Six binary columns naming the content category of an article.
data_channels = [
    'data_channel_is_' + channel
    for channel in (
        'lifestyle', 'entertainment', 'bus', 'socmed', 'tech', 'world',
    )
]
# Seven binary columns naming the weekday of publication.
weekdays = [
    'weekday_is_' + day
    for day in (
        'monday', 'tuesday', 'wednesday', 'thursday',
        'friday', 'saturday', 'sunday',
    )
]

# mergeFeatures edits df_X in place and returns that same object, so after
# these calls `df` is simply another name for df_X.
df = mergeFeatures(df_X, data_channels, 'data_channel')
df = mergeFeatures(df_X, weekdays, 'pub_weekday')


#%% Remove Outliers and Normalize features
import sklearn.preprocessing as prep

# Remove outliers one column at a time: a row survives only if its value lies
# within 5 standard deviations of the column mean. The mean and std of each
# column are computed on the rows that survived the previous columns.
for col in df_X.columns:
    column = df_X[col]
    within_limit = np.abs(column - column.mean()) <= 5 * column.std()
    df_X = df_X[within_limit]

# Keep only the targets whose index labels survived the filtering.
sr_Y = sr_Y.loc[df_X.index]

def standarize(arr_X):
    """Min-max scale the features, then subtract each row's mean from that row.

    NOTE(review): the centering runs across features (axis=1), i.e. per
    sample rather than per column - confirm that this is intended.
    """
    scaled = prep.MinMaxScaler().fit_transform(arr_X)
    # keepdims leaves the means as a column vector so they broadcast per row.
    row_means = scaled.mean(axis=1, keepdims=True)
    return scaled - row_means

# Hand the feature table to standarize() as a plain array.
arr_X = standarize(df_X.values)

#%% Binarize Y
'''
As we mentioned before, we use the median value 1400 as the threshold to binarize the target feature and divide the data points into 2 classes. Because of the outlier removal, the two classes no longer have the same size, but they are still more or less balanced.
'''
# Share counts reshaped into a single column before binarizing.
shares_column = sr_Y.values.reshape(-1, 1)
# Split into two classes, using the original median (1400) as the threshold.
arr_Y = prep.binarize(shares_column, threshold=1400)
sr_Y = pd.Series(arr_Y.ravel())

# Class labels and how many samples fall in each class.
unique_items, counts = np.unique(arr_Y, return_counts=True)
#%% 
# Hold out 20% of the rows for testing, then 20% of the remaining rows for
# validation. Fixed seeds keep the splits reproducible.
X_train, X_test, Y_train, Y_test = train_test_split(
    arr_X, arr_Y, test_size=0.2, random_state=7
)
X_train, X_val, Y_train, Y_val = train_test_split(
    X_train, Y_train, test_size=0.2, random_state=0
)

num_features = X_train.shape[1]

# Data scaling: fit the scaler on the training split only, then apply the
# same transform to all three splits.
ss = StandardScaler().fit(X_train)
X_train, X_val, X_test = (
    ss.transform(split) for split in (X_train, X_val, X_test)
)

#%% Train Model
import pickle
import os.path
import sys
sys.path.append('..')
from copy import deepcopy
import torch
import torch.nn as nn
from tqdm.auto import tqdm
from torch.utils.data import DataLoader
import torch.optim as optim
import lightgbm as lgb
from lightgbm.callback import log_evaluation, early_stopping
# Prefer the GPU, but fall back to the CPU so the script can still run on a
# machine without CUDA instead of failing at the first `.to(device)` call.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#%% 
if not os.path.isfile('news model.pkl'):
    # No cached model on disk: fit a LightGBM binary classifier.
    params = dict(
        max_bin=512,
        learning_rate=0.05,
        boosting_type="gbdt",
        objective="binary",
        metric="binary_logloss",
        num_leaves=10,
        verbose=-1,
        min_data=100,
        boost_from_average=True,
    )

    d_train = lgb.Dataset(X_train, label=Y_train)
    d_val = lgb.Dataset(X_val, label=Y_val)

    # Up to 10000 boosting rounds, evaluated on the validation split, with a
    # 50-round early-stopping callback and a log line every 1000 rounds.
    callbacks = [
        log_evaluation(period=1000),
        early_stopping(stopping_rounds=50),
    ]
    model = lgb.train(
        params,
        d_train,
        num_boost_round=10000,
        valid_sets=[d_val],
        callbacks=callbacks,
    )

    # Cache the fitted model so later runs can skip training.
    with open('news model.pkl', 'wb') as f:
        pickle.dump(model, f)

else:
    print('Loading saved model')
    # NOTE: unpickling can execute arbitrary code; only load a cache file
    # that this script wrote itself.
    with open('news model.pkl', 'rb') as f:
        model = pickle.load(f)

#%% Train surrogate
import torch
import torch.nn as nn
from fastshap.utils import MaskLayer1d
from fastshap import Surrogate, KLDivLoss

# Select device: use the GPU when one is available, otherwise fall back to
# the CPU so the `.to(device)` calls below do not fail without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

#%% 
# Reuse a previously saved surrogate network when one exists; otherwise
# build, train and save a new one.
if os.path.isfile('news surrogate.pt'):
    print('Loading saved surrogate model')
    # NOTE(review): this unpickles a whole module object; recent PyTorch
    # releases restrict torch.load to weights by default - confirm this call
    # still works with the installed version.
    surr = torch.load('news surrogate.pt').to(device)
    surrogate = Surrogate(surr, num_features)

else:
    # Surrogate network. The first Linear layer takes 2 * num_features
    # inputs - presumably the masked features plus the appended mask
    # (append=True); verify against MaskLayer1d. The output has two units.
    surr = nn.Sequential(
        MaskLayer1d(value=0, append=True),
        nn.Linear(2 * num_features, 128),
        nn.ELU(inplace=True),
        nn.Linear(128, 128),
        nn.ELU(inplace=True),
        nn.Linear(128, 2),
    ).to(device)

    # Wrap the network in the fastshap Surrogate helper.
    surrogate = Surrogate(surr, num_features)

    def original_model(x):
        """Run the LightGBM model on a tensor batch and return two columns.

        The model's prediction p goes in the second column and 1 - p in the
        first; the result is a float32 tensor on the same device as ``x``.
        """
        positive = model.predict(x.cpu().numpy())
        two_columns = np.stack([1 - positive, positive], axis=1)
        return torch.tensor(
            two_columns, dtype=torch.float32, device=x.device)

    # Fit the surrogate to the original model's outputs.
    surrogate.train_original_model(
        X_train,
        X_val,
        original_model,
        batch_size=64,
        max_epochs=100,
        loss_fn=KLDivLoss(),
        validation_samples=10,
        validation_batch_size=10000,
        verbose=True,
    )

    # Save from the CPU, then move the network back to the working device.
    torch.save(surr.cpu(), 'news surrogate.pt')
    surr.to(device)

#%% with pair sampling
from simshap.simshap_sampling import SimSHAPSampling
from models import SimSHAPTabular
# Explainer network to be trained.
explainer = SimSHAPTabular(
    in_dim=num_features, hidden_dim=128, out_dim=2
).to(device)

# SimSHAP trainer: pairs the explainer with the surrogate as its imputer.
simshap = SimSHAPSampling(
    explainer=explainer, imputer=surrogate, device=device
)

# Training options; the first five come from the command line.
train_options = dict(
    batch_size=args.batch_size,
    num_samples=args.num_samples,
    paired_sampling=args.pair_sampling,
    max_epochs=args.epochs,
    lr=args.lr,
    bar=False,
    validation_samples=128,
    verbose=True,
    lookback=10,
    lr_factor=0.5,
)

# Train on the full training split; validate on the first 100 validation rows.
simshap.train(X_train, X_val[:100], **train_options)

# Save the explainer from the CPU, then move it back to the working device.
explainer.cpu()
torch.save(explainer, 'news simshap ablation.pt')
explainer.to(device)